# -*- coding: utf-8 -*-
import torch
from torch import nn
import copy
import torch.optim as optim
import torch.nn.functional as F
import yaml
from torch.utils.hooks import unserializable_hook
# Compute device used below when placing the networks and loading the
# checkpoint: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

@unserializable_hook
def replace_nan_with_zero(grad):
    """Return a copy of ``grad`` in which every NaN entry is replaced by 0.0.

    The input tensor is not modified. Because ``nan_to_num`` is called with
    its default ``posinf``/``neginf`` arguments, infinite entries are also
    mapped to the largest/smallest finite value of the tensor's dtype.

    NOTE(review): the name and the ``unserializable_hook`` decorator suggest
    this is meant to be registered as a gradient hook; no registration is
    visible in this file.
    """
    sanitized = grad.nan_to_num(nan=0.0)
    return sanitized

class Ddpg(nn.Module):
    """Policy network plus Q-value network, each with its own Adam optimizer.

    ``cfg`` must expose ``state_dim``, ``index_dim``, ``hidden_size``,
    ``agent_lr``, ``ep_len`` and ``algorithm``.

    * ``policy`` maps a ``state_dim + index_dim`` input to a ``state_dim``
      output through two hidden LeakyReLU(0.1) layers.
    * ``qvalue`` maps a ``state_dim * 2 + index_dim`` input (state and action
      concatenated on the last axis) to a single scalar.

    Constructing an instance reads ``Parameter/initial_policy.pt`` (relative
    to the current working directory) and loads it into ``policy``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.policy = nn.Sequential(
            nn.Linear(self.cfg.state_dim + self.cfg.index_dim, self.cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(self.cfg.hidden_size, self.cfg.state_dim),
        ).to(device)

        self.qvalue = nn.Sequential(
            nn.Linear(self.cfg.state_dim * 2 + self.cfg.index_dim, self.cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(self.cfg.hidden_size, self.cfg.hidden_size),
            nn.LeakyReLU(0.1),
            nn.Linear(self.cfg.hidden_size, 1),
        ).to(device)

        self.p_optimizer = optim.Adam(self.policy.parameters(), lr=cfg.agent_lr, weight_decay=1e-6)
        self.q_optimizer = optim.Adam(self.qvalue.parameters(), lr=cfg.agent_lr, weight_decay=1e-6)

        # The checkpoint keys carry a "policy." qualifier that the bare
        # nn.Sequential does not expect, so it is stripped before loading.
        state_dict = torch.load('Parameter/initial_policy.pt', map_location=device, weights_only=True)
        new_state_dict = {key.replace("policy.", ""): value for key, value in state_dict.items()}
        self.policy.load_state_dict(new_state_dict)

    def forward(self, state_set, offset=None):
        """Compute the policy action for ``state_set``, with an additive offset.

        Args:
            state_set: input tensor for ``self.policy``; its last axis must be
                ``state_dim + index_dim`` wide.
            offset: optional tensor added to the action. When ``None`` a fresh
                one is sampled: uniform in ``[-0.05, 0.05)`` if
                ``cfg.algorithm == "diayn"``, otherwise Gaussian with standard
                deviation 0.05.

        Returns:
            A 3-tuple ``(xyz_action, xyz_action2, offset_copy)``:

            * ``offset`` is ``None``: ``xyz_action`` includes the freshly
              sampled (clamped) offset and ``xyz_action2`` is the number
              ``0.0``.
            * ``offset`` is given: ``xyz_action`` is the offset-free action and
              ``xyz_action2`` is the action with the (clamped) offset added.

            ``offset_copy`` is a clone of the *unclamped* offset, so feeding it
            back in reproduces the same perturbed action.
        """
        xyz_raw = self.policy(state_set)

        # F.normalize divides by max(L2 norm, eps); with eps=0.1 outputs whose
        # norm is below 0.1 are scaled by 1/0.1 instead of being blown up to
        # unit length.
        base_action = F.normalize(xyz_raw, p=2, dim=-1, eps=1e-1)
        # NOTE(review): presumably a small penalty that keeps the raw policy
        # output from growing without bound -- confirm intent.
        base_action = base_action - (xyz_raw ** 2) / 10000

        scale = 0.05

        if offset is None:
            # *_like already allocates on base_action's device and dtype.
            if self.cfg.algorithm == "diayn":
                offset = (torch.rand_like(base_action) * 2 - 1) * scale
            else:
                offset = torch.randn_like(base_action) * scale
            xyz_action = base_action + torch.clamp(offset, min=-scale, max=scale)
            xyz_action2 = 0
        else:
            xyz_action = base_action
            xyz_action2 = base_action + torch.clamp(offset, min=-scale, max=scale)

        xyz_action = xyz_action * 10 / self.cfg.ep_len
        xyz_action2 = xyz_action2 * 10 / self.cfg.ep_len

        return xyz_action, xyz_action2, offset.clone()

    def policy_train(self, loss) -> float:
        """Take one optimizer step on the policy for ``loss``; return its value."""
        self.policy.train()

        self.p_optimizer.zero_grad()
        loss.backward()

        self.p_optimizer.step()
        return loss.item()

    def queue_forward(self, traj, action, reward):
        """Evaluate the Q-network on a trajectory and build both losses.

        Args:
            traj: states stacked along axis 0, one more entry than ``action``
                (the final state is not paired with an action).
            action: actions stacked along axis 0, aligned with ``traj[:-1]``.
            reward: per-step rewards stacked along axis 0; must broadcast
                against the Q-network output once a trailing axis is added.

        Returns:
            ``(policy_loss, q_loss)`` where ``policy_loss`` is the negated sum
            of Q-values over all (state, action) pairs and ``q_loss`` is the
            mean squared error between the Q-values and 0.01 times the
            reward-to-go (reverse cumulative sum of ``reward`` along axis 0).
        """
        # Reverse cumulative sum: entry t holds reward[t] + reward[t+1] + ...
        reward_sum = torch.flip(torch.cumsum(torch.flip(reward, dims=[0]), dim=0), dims=[0])
        # Regression target only -- no gradient flows back through the reward.
        reward_sum = reward_sum.unsqueeze(-1).detach().clone()

        queue_input = torch.cat((traj[:-1], action), dim=-1)

        diff = self.qvalue(queue_input) - (reward_sum * 0.01)
        loss = torch.mean(diff ** 2)

        # A second forward pass is kept on purpose: it gives the two returned
        # losses separate Q-network graphs, so backpropagating one does not
        # free buffers the other still needs.
        return -torch.sum(self.qvalue(queue_input)), loss

    def queue_train(self, loss) -> float:
        """Take one optimizer step on the Q-network for ``loss``; return its value.

        The graph is retained after the backward pass so that another loss
        sharing parts of it can still be backpropagated afterwards.
        """
        self.qvalue.train()

        self.q_optimizer.zero_grad()
        loss.backward(retain_graph=True)

        self.q_optimizer.step()
        return loss.item()